function solution                       = solve(pars,Nu,loud)
% SOLVE  Solve the model on a discretized (M,X) grid.
%
%   solution = solve(pars,Nu,loud)
%
%   Steps performed (in order):
%     0. Reduced-form parameters are added to pars, admissibility checks on
%        M_e and on the time step are run, and the X and M grids are built.
%     1. The solution for t > T is obtained by calling solve_idsds with
%        pars.theta temporarily set to 0.
%     2. Backward induction over Nt = ceil(pars.T/Nu.Deltat) steps gives the
%        value B for t <= T; the time-0 threshold Phi_0 is found by linear
%        interpolation of the sign change of B_0 along the M dimension.
%     3. If Nu.solutionmode > 0: impulse responses over pars.tIRF.
%     4. If Nu.solutionmode > 0: stationary distribution of the time-0
%        transition matrix on the (M,X) grid.
%        Otherwise, if Nu.mc_method > 0, only the transition matrix is built.
%
%   Inputs
%     pars  struct; fields read here: r, k, theta, C, M_c, M_e, sigma, T,
%           and tIRF (only when Nu.solutionmode > 0).
%     Nu    struct; fields read here: N, NX, theta_target, tol_solve,
%           solutionmode, and mc_method (only when Nu.solutionmode <= 0).
%     loud  logical/numeric flag; when true, progress and parameter tables
%           are printed.
%
%   Output
%     solution.postT  output of solve_idsds plus fields B and Phi.
%     solution.preT   B_0, Phi_0, B, Ap, Am, J_M, J_m, J_p, crit, Nu, pars
%                     (and mu, mu_exact, spL when computed).
%     solution.IRF    tvec (and Nt, IRF_M, IRF_X, IRF_a when computed).
%
%   Errors are raised (via error) when any grid condition fails, when the
%   t > T solution has not converged, when the M grid does not bracket the
%   threshold, or when the transition matrix is not a Markov matrix.

tt_total = tic;

if loud
    fprintf('\n');
end

%save('runs/solve_environment.mat')

% Reduced-form parameters
pars.A_M                                = 1/(pars.r+pars.k+pars.theta);
pars.A_X                                = 1/(pars.r+2*pars.k)*pars.C;
pars.A_N                                = pars.k*pars.A_X;
pars.uPhi0                              = pars.M_c -(pars.r+pars.k+pars.theta)/(pars.r+pars.k)*(pars.M_c - pars.M_e);
pars.oPhi1                              = pars.uPhi0 + pars.A_X/pars.A_M*( 1  + (pars.k)/(pars.r+pars.k) );
pars.xi                                 = (pars.r+pars.k)/(pars.r+pars.k+pars.theta)*pars.sigma;
M_e_ub                                  = pars.M_c + pars.xi*sqrt(Nu.N/pars.theta) - pars.C;
M_e_lb                                  = pars.M_c - pars.xi*sqrt(Nu.N/pars.theta);
C_ub                                    = pars.M_c + pars.xi*sqrt(Nu.N/pars.theta) - pars.M_e;

% Time step implied by the grid; computed locally because Nu.Deltat is only
% assigned further below, and reused in the diagnostics of both M checks.
dt_implied                              = 1/(Nu.N*Nu.theta_target);

if ~( pars.M_e >= M_e_lb-Nu.tol_solve & pars.M_e <= M_e_ub+Nu.tol_solve & M_e_lb <= M_e_ub )
    fprintf('Conditions on M grid not satisfied \n');
    fprintf('Delta t = %5.4f, 1/(Nu.N*Nu.theta_target) = %5.4f \n',dt_implied,dt_implied);
    fprintf('M_e = %.6f, M_e_lb = %.6f \n',pars.M_e,M_e_lb);
    fprintf('M_e = %.6f, M_e_ub = %.6f \n',pars.M_e,M_e_ub);
    fprintf('(M_e - M_e_ub)/tol_solve = %.6f\n',(pars.M_e-M_e_ub)/Nu.tol_solve);
    fprintf('C = %.6f, C_ub = %.6f \n',pars.C,C_ub);
    error('Stop here');
end

if loud
    fprintf('Model parameters \n');
    fprintf('r                              = %5.3f \n',pars.r);
    fprintf('k                              = %5.3f \n',pars.k);
    fprintf('C                              = %5.3f \n',pars.C);
    fprintf('M_e                            = %5.3f \n',pars.M_e);
    fprintf('M_c                            = %5.3f \n',pars.M_c);
    fprintf('sigma                          = %5.3f \n',pars.sigma);
    fprintf('theta                          = %5.3f \n',pars.theta);
    fprintf('T                              = %5.0f months \n',pars.T);
    fprintf('\n');
end

% Discretization parameters
Nu.X                                    = linspace(0,1,Nu.NX);                                  % Grid for X
Nu.hX                                   = min(diff(Nu.X));                                      % Step size on the grid for X
Nu.NM                                   = 2*Nu.N+1;                                             % Number of gridpoints for M
Nu.M_min                                = pars.M_c - sqrt(Nu.N/Nu.theta_target)*pars.sigma;     % Lower bound for M grid
Nu.M_max                                = pars.M_c + sqrt(Nu.N/Nu.theta_target)*pars.sigma;     % Upper bound for M grid
Nu.M                                    = linspace(Nu.M_min,Nu.M_max,Nu.NM)';                   % Grid for M
Nu.hM                                   = pars.sigma/(sqrt(Nu.N*Nu.theta_target));              % Step size on the grid for M
Nu.Xgrid                                = repmat(Nu.X,[Nu.NM,1]);                               % Meshed grid for X
Nu.Mgrid                                = repmat(Nu.M,[1,Nu.NX]);                               % Meshed grid for M
Nu.Deltat                               = 1/(Nu.N*Nu.theta_target);                             % Time step, fraction of a month
Nu.tol                                  = (pars.r+pars.k)*Nu.Deltat*1e-4;
Nu.tBI                                  = pars.T;

if loud
    fprintf('Discretization parameters \n');
    fprintf('theta_target                   = %5.3f \n',Nu.theta_target);
    fprintf('NX                             = %5.0f \n',Nu.NX);
    fprintf('hX                             = %5.3f \n',Nu.hX);
    fprintf('N                              = %5.0f \n',Nu.N);
    fprintf('N_M                            = %5.0f \n',Nu.NM);
    fprintf('M_min                          = %5.3f \n',Nu.M(1));
    fprintf('M_max                          = %5.3f \n',Nu.M(end));
    fprintf('uPhi(0)                        = %5.3f \n',pars.uPhi0);
    fprintf('oPhi(1)                        = %5.3f \n',pars.oPhi1);
    fprintf('dt                             = %5.3f months, = %5.3f days \n',Nu.Deltat,Nu.Deltat*30);
    fprintf('\n');
end

% Condition on the X grid: the time step must not exceed dt_max_X
dt_max_X                                = min(1/(pars.r+2*pars.k),1/(Nu.NX*pars.k));
if loud
    fprintf('Condition on X grid size \n');
    fprintf('min(1/(r+2k),1/(N_X*k))        = %5.3f \n',dt_max_X);
    fprintf('dt                             = %5.3f \n',Nu.Deltat);
    if Nu.Deltat <= dt_max_X
        fprintf('Condition on X grid size satisfied \n');
    else
        fprintf('Condition on X grid size not satisfied \n');
        error('Stop here');
    end
    fprintf('\n');
else
    if Nu.Deltat > dt_max_X
        fprintf('Delta t = %5.4f, min(1/(pars.r+2*pars.k),1/(Nu.NX*pars.k)) = %5.4f \n',Nu.Deltat,dt_max_X);
        error('Stop here');
    end
end

% Conditions on the M grid.
% NOTE(review): the loud and quiet paths test different conditions (grid
% bounds vs. uPhi0/oPhi1 in loud mode, M_e bounds otherwise); this is kept
% as in the original -- confirm whether they are meant to agree.
if loud
    fprintf('Conditions on M grid size \n');
    fprintf('1/(N_M*theta_target)           = %5.3f \n',1/(Nu.NM*Nu.theta_target));
    fprintf('dt                             = %5.3f \n',Nu.Deltat);
    if Nu.Deltat <= 1/(Nu.N*Nu.theta_target) & Nu.M_min <= pars.uPhi0 & Nu.M_max >= pars.oPhi1
        % Conditions hold: continue (a stray error() here used to abort
        % every loud run).
        fprintf('Conditions on M grid satisfied \n');
    else
        fprintf('Conditions on M grid not satisfied \n');
        error('Stop here');
    end
    fprintf('\n');
else
    if ~( Nu.Deltat <= 1/(Nu.N*Nu.theta_target) & pars.M_e >= M_e_lb-Nu.tol_solve & pars.M_e <= M_e_ub+Nu.tol_solve )
        fprintf('Conditions on M grid not satisfied \n');
        fprintf('Delta t = %5.4f, 1/(Nu.N*Nu.theta_target) = %5.4f \n',Nu.Deltat,dt_implied);
        fprintf('M_e = %.6f, M_e_lb = %.6f \n',pars.M_e,M_e_lb);
        fprintf('M_e = %.6f, M_e_ub = %.6f \n',pars.M_e,M_e_ub);
        error('Stop here');
    end
end

%%%%%%%%%%%%%%%%%
% Solution setup
%%%%%%%%%%%%%%%%%

% Transition matrices for X (applied by right-multiplication, B*J)
dX_m                                    = -pars.k*Nu.Deltat*Nu.X;
J_m                                     = diag(ones(Nu.NX,1),0) + (Nu.hX)^(-1)*diag(dX_m,0) + (Nu.hX)^(-1)*diag(-dX_m(2:end),1);
dX_p                                    = pars.k*Nu.Deltat*(1-Nu.X);
J_p                                     = diag(ones(Nu.NX,1),0) + (Nu.hX)^(-1)*diag(-dX_p,0) + (Nu.hX)^(-1)*diag(dX_p(1:(end-1)),-1);

% Transition matrix for M (applied by left-multiplication, J_M*B):
% u and d are the one-step weights on the upper and lower neighbour.
u                                       = (pars.sigma^2 + Nu.hM*pars.theta*(pars.M_c - Nu.M))/(2*Nu.hM^2)*Nu.Deltat;
d                                       = (pars.sigma^2 - Nu.hM*pars.theta*(pars.M_c - Nu.M))/(2*Nu.hM^2)*Nu.Deltat;
if pars.theta == 0
    % Without mean reversion the end points of the grid are made absorbing
    u(1)                                = 0;
    d(1)                                = 0;
    u(end)                              = 0;
    d(end)                              = 0;
end
J_M                                     = diag(u(1:(end-1)),1) + diag(d(2:end),-1) + diag(1-u-d,0);

% Flow profits
Flow                                    = (pars.M_e - Nu.Mgrid + pars.C*Nu.Xgrid)*Nu.Deltat;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Step 1: Solution for t > T
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

tt = tic;

% Compute solution with theta switched off.
% NOTE(review): Nu.theta_target is overwritten by pars.theta when the latter
% is positive, and pars.theta is afterwards restored from Nu.theta_target --
% so an input pars.theta of 0 leaves this block equal to Nu.theta_target.
% Kept as in the original; confirm this is intended.
if pars.theta > 0
    Nu.theta_target                     = pars.theta;
end
pars.theta                              = 0;
loud_local                              = 0;
solution.postT                          = solve_idsds(pars,Nu,loud_local);
pars.theta                              = Nu.theta_target;

% Check convergence of IDSDS for t > T: the thresholds reached from above
% and from below must coincide up to Nu.tol
Phi_gap                                 = solution.postT.Phi_above - solution.postT.Phi_below;
if max(abs(Phi_gap(:))) > Nu.tol
    error('Error: IDSDS does not converge for t > T');
end

% Record solution
solution.postT.B                        = solution.postT.B_below;
solution.postT.Phi                      = solution.postT.Phi_below;

tt = toc(tt);
if loud
    fprintf('Time for t>T solution:                 %d minutes and %4.2f seconds\n', floor(tt/60), rem(tt,60));
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Step 2: Solution for t <= T
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

tt = tic;

% Horizon of backward induction
Nt                                      = ceil(Nu.tBI/Nu.Deltat);

% Matrices to store output from backward iteration; the last time slice
% holds the terminal (t > T) solution
[B_preT,Ap_preT,Am_preT]                = deal(NaN(Nu.NM,Nu.NX,Nt+1));
B_preT(:,:,end)                         = solution.postT.B;
[AM_preT,uM_preT,crit_preT]             = deal(NaN(1,Nt+1));
crit_preT(:,end)                        = 0;
AM_preT(:,end)                          = 1/(pars.r+pars.k);
uM_preT(:,end)                          = pars.M_e;
[Phi_best_preT,Phi_worse_preT]          = deal(NaN(1,Nu.NX,Nt+1));
Phi_best_preT(1,:,end)                  = solution.postT.Phi_best;
Phi_worse_preT(1,:,end)                 = solution.postT.Phi_worse;

% Backward induction
for nt=1:Nt

    B_next                              = B_preT(:,:,Nt+2-nt);
    Ap                                  = B_next >= 0;                      % States where the J_p transition applies
    Am                                  = ones(Nu.NM,Nu.NX) - Ap;           % Complement: J_m transition applies
    B_preT(:,:,Nt+1-nt)                 = Flow + (1-(pars.r+pars.k)*Nu.Deltat)*( Ap.*(J_M*B_next*J_p) + Am.*(J_M*B_next*J_m) );
    Ap_preT(:,:,Nt+1-nt)                = Ap;
    Am_preT(:,:,Nt+1-nt)                = Am;
    crit_preT(Nt+1-nt)                  = max(max(abs(B_preT(:,:,Nt+1-nt)-B_next)));

    %AM_preT(Nt+1-nt)                        = 1/(pars.r+pars.k+pars.theta) + 1/((pars.r+pars.k)*(pars.r+pars.k+pars.theta))*pars.theta*exp(-(pars.r+pars.k+pars.theta)*nt*Nu.Deltat);
    %uM_preT(Nt+1-nt)                        = pars.M_c - 1/(AM_preT(Nt+1-nt)*(pars.r+pars.k))*( pars.M_c - pars.M_e );
    %Phi_worse_preT(1,:,Nt+1-nt)             = uM_preT(Nt+1-nt) + pars.A_X/AM_preT(Nt+1-nt)*Nu.X;
    %Phi_best_preT(1,:,Nt+1-nt)              = Phi_worse_preT(1,:,Nt+1-nt) + pars.A_X/AM_preT(Nt+1-nt)*(pars.k/(pars.r+pars.k));

end

% Adoption threshold at time 0: for each X, locate the first M gridpoint
% where B_0 turns negative and interpolate linearly to the zero crossing
B_0                                     = B_preT(:,:,1);
Phi_0                                   = NaN(Nu.NX,1);
for nx=1:Nu.NX
    if B_0(1,nx) <= 0
        error('Min value of M is too high');
    elseif B_0(end,nx) >= 0
        error('Max value of M is too low');
    else
        nm                              = find( B_0(:,nx) < 0,1,'first' );
        if isempty(nm) || (nm == 1)
            error('Unclear error with the threshold');
        else
            Phi_0(nx)                   = Nu.M(nm-1) - Nu.hM/(B_0(nm,nx)-B_0(nm-1,nx))*B_0(nm-1,nx);
        end
    end
end

tt = toc(tt);
if loud
    fprintf('Time for t<T solution:                 %d minutes and %4.2f seconds\n', floor(tt/60), rem(tt,60));
end

% Record solution
solution.preT.B_0                       = B_preT(:,:,1);
solution.preT.Phi_0                     = Phi_0;
solution.preT.B                         = B_preT;
%solution.preT.Phi_best_0                = Phi_best_preT(1,:,1);
%solution.preT.Phi_worse_0               = Phi_worse_preT(1,:,1);
%solution.preT.Phi_best                  = Phi_best_preT;
%solution.preT.Phi_worse                 = Phi_worse_preT;
solution.preT.Ap                        = Ap_preT;
solution.preT.Am                        = Am_preT;
solution.preT.J_M                       = J_M;
solution.preT.J_m                       = J_m;
solution.preT.J_p                       = J_p;
solution.preT.crit                      = crit_preT;
solution.preT.Nu                        = Nu;
solution.preT.pars                      = pars;

if Nu.solutionmode > 0

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Step 3: IRFs
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

    tt = tic;

    % Horizon of IRFs
    Nt                                  = ceil(pars.tIRF/Nu.Deltat);
    if pars.tIRF > pars.T - Nu.Deltat
        error('IRF horizon exceeds T');
    end

    % Matrices to store IRFs
    [IRF_M,IRF_X,IRF_a]                 = deal(NaN(Nu.NM,Nu.NX,Nt+1));
    IRF_M(:,:,1)                        = repmat(Nu.M,[1 Nu.NX]);
    IRF_X(:,:,1)                        = repmat(Nu.X,[Nu.NM 1]);
    IRF_a(:,:,1)                        = solution.preT.B_0 >= 0;
    for nt=1:Nt

        % The policy at step nt must already be (numerically) stationary
        if crit_preT(nt) > Nu.tol
            error('IRF horizon too close to T');
        end

        IRF_M(:,:,nt+1)                 = Ap_preT(:,:,nt).*(J_M*IRF_M(:,:,nt)*J_p) + Am_preT(:,:,nt).*(J_M*IRF_M(:,:,nt)*J_m);
        IRF_X(:,:,nt+1)                 = Ap_preT(:,:,nt).*(J_M*IRF_X(:,:,nt)*J_p) + Am_preT(:,:,nt).*(J_M*IRF_X(:,:,nt)*J_m);
        IRF_a(:,:,nt+1)                 = Ap_preT(:,:,nt).*(J_M*IRF_a(:,:,nt)*J_p) + Am_preT(:,:,nt).*(J_M*IRF_a(:,:,nt)*J_m);

    end

    %IRF_X                                  = IRF_X - repmat(IRF_X(:,:,1),[1 1 Nt+1]);
    %IRF_M                                   = IRF_M - repmat(pars.M_c*ones(size(IRF_M(:,:,1))),[1 1 Nt+1]);

    % Record IRFs
    solution.IRF.Nt                     = Nt;
    solution.IRF.tvec                   = 0:Nu.Deltat:(Nt*Nu.Deltat);
    solution.IRF.IRF_M                  = IRF_M;
    solution.IRF.IRF_X                  = IRF_X;
    solution.IRF.IRF_a                  = IRF_a;

    tt = toc(tt);
    if loud
        fprintf('Time for IRF computation:              %d minutes and %4.2f seconds\n', floor(tt/60), rem(tt,60));
    end

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Step 4: "Stationary" distribution
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

    tt = tic;

    % Joint stationary distribution of (X,M): eigenvector of the transposed
    % time-0 transition matrix associated with the eigenvalue 1
    spL                                 = build_markov_matrix(solution.preT.Ap(:,:,1),solution.preT.Am(:,:,1),J_p,J_m,J_M);
    %spL                                     = spL./repmat(sum(spL,2),[1 Nu.NX*Nu.NM]);
    [v,d]                               = eigs(spL',1,'largestreal');
    if abs(d-1) > 1e-10
        error('No stationary distribution');
    end
    v(abs(v) < 1e-10)                   = 0;                                % Drop numerical noise
    if max(v) <= 0
        v                               = -v;                               % Eigenvector sign is arbitrary
    end
    v                                   = v/norm(v,1);
    solution.preT.mu                    = reshape(v,[Nu.NM,Nu.NX]);
    solution.preT.spL                   = spL;

    % Exact stationary distribution of M, normalized to sum to one on the grid
    solution.preT.mu_exact              = normpdf(Nu.M,pars.M_c,pars.sigma/sqrt(2*pars.theta));
    solution.preT.mu_exact              = solution.preT.mu_exact/norm(solution.preT.mu_exact,1);

    tt = toc(tt);
    if loud
        fprintf('Time for stat distro:                  %d minutes and %4.2f seconds\n', floor(tt/60), rem(tt,60));
    end

    if loud
        fprintf('\n');
    end

elseif Nu.mc_method > 0

    % Markov transition matrix only.
    % NOTE(review): Nt here is still the backward-induction horizon from
    % Step 2, not an IRF horizon -- confirm this is the intended tvec.
    solution.preT.spL                   = build_markov_matrix(solution.preT.Ap(:,:,1),solution.preT.Am(:,:,1),J_p,J_m,J_M);
    solution.IRF.tvec                   = 0:Nu.Deltat:(Nt*Nu.Deltat);

else

    % NOTE(review): Nt here is still the backward-induction horizon from
    % Step 2, not an IRF horizon -- confirm this is the intended tvec.
    solution.IRF.tvec                   = 0:Nu.Deltat:(Nt*Nu.Deltat);

end

tt_total = toc(tt_total);

if loud
    fprintf('Total time:                            %d minutes and %4.2f seconds\n', floor(tt_total/60), rem(tt_total,60));
end

end

function spL = build_markov_matrix(Ap,Am,J_p,J_m,J_M)
% BUILD_MARKOV_MATRIX  Sparse one-step transition matrix on the stacked grid.
%
%   Acts on vec(B) the way  Ap.*(J_M*B*J_p) + Am.*(J_M*B*J_m)  acts on B,
%   using vec(J_M*B*J) = kron(J',J_M)*vec(B). Raises an error when a row
%   sum deviates from 1 by more than 1e-10.

nS                                      = numel(Ap);
spAp                                    = spdiags(Ap(:),0,nS,nS);
spAm                                    = spdiags(Am(:),0,nS,nS);
spJM                                    = sparse(J_M);
spL                                     = spAp*(kron(sparse(J_p'),spJM)) + spAm*(kron(sparse(J_m'),spJM));
if max(abs( sum(spL,2) - 1 )) > 1e-10
    error('Not a Markov matrix');
end

end
